{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a3b8e999",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-22T01:23:51.523583Z",
     "start_time": "2024-05-22T01:23:41.768277Z"
    }
   },
   "outputs": [],
   "source": [
    "import gym\n",
    "import numpy as np\n",
     "from IPython import display\n",
     "import matplotlib\n",
     "import matplotlib.pyplot as plt\n",
    "from IPython.core.interactiveshell import InteractiveShell\n",
    "InteractiveShell.ast_node_interactivity=\"last_expr\"\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0581d893",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-22T01:23:53.734675Z",
     "start_time": "2024-05-22T01:23:53.728161Z"
    }
   },
   "outputs": [],
   "source": [
    "class GymHelper:\n",
    "    def __init__(self,env,figsize=(3,3)):\n",
    "        self.env=env\n",
    "        self.figsize=figsize\n",
    "        plt.figure(figsize=figsize)\n",
    "        plt.title(self.env.spec.id if hasattr(env.spec,\"id\") else \"\")\n",
    "        self.img=plt.imshow(env.render())\n",
    "    def render(self,title=None):\n",
    "        img_data=self.env.render()\n",
    "        self.img.set_data(img_data)\n",
    "        display.display(plt.gcf())\n",
    "        display.clear_output(wait=True)\n",
    "        if title:\n",
    "            plt.title(title)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b5e56a6d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-22T01:23:59.754424Z",
     "start_time": "2024-05-22T01:23:55.489115Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
     "from tqdm import tqdm\n",
    "import collections\n",
    "import time\n",
    "import random\n",
    "import sys\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7188059c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-22T01:24:01.674500Z",
     "start_time": "2024-05-22T01:24:01.670512Z"
    }
   },
   "outputs": [],
   "source": [
    "from gym import spaces"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "c4cd0cfc",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-22T01:47:57.741942Z",
     "start_time": "2024-05-22T01:47:57.716324Z"
    }
   },
   "outputs": [],
   "source": [
    "#继承gym.env环境\n",
    "class SnakeEnv(gym.Env):\n",
    "    def __init__(self,grid_size=16):\n",
    "        super(SnakeEnv,self).__init__()\n",
    "        #保存网格大小\n",
    "        self.grid_size=grid_size\n",
    "        #蛇的初始位置,设为中心点\n",
    "        self.snake=[(self.grid_size//2,self.grid_size//2)]\n",
    "        #定义动作空间\n",
    "        self.action_space=spaces.Discrete(4)\n",
    "        #定义状态空间,是一个grid*grid*3的三维空间,值域为0-1\n",
     "        self.observation_space=spaces.Box(low=0,high=1,shape=(grid_size,grid_size,3),dtype=np.float32)\n",
    "        #初始化食物\n",
    "        self.food=None\n",
    "        #蛇的总步数\n",
    "        self.steps=0\n",
     "        #饥饿程度(自上次进食以来的步数)\n",
    "        self.hunary=0\n",
    "        #上一步的动作,初始化一个特殊值即可\n",
    "        self.last_action=-10\n",
    "        #场上食物的数量\n",
    "        self.num_food=10\n",
    "        #游戏最大步数\n",
    "        self.max_steps=200\n",
    "    def reset(self):\n",
    "        #蛇回到初始位置\n",
    "        self.snake=[(self.grid_size//2,self.grid_size//2)]\n",
    "        #食物重新生成\n",
    "        self.food=self._generate_food()\n",
     "        #步数和饥饿度归零\n",
    "        self.steps=0\n",
     "        self.hunary=0\n",
     "        #上一步动作也要重置,避免状态跨回合泄漏\n",
     "        self.last_action=-10\n",
     "        return self._get_state(),{}\n",
    "    def step(self,action):\n",
    "        #确定蛇头的位置\n",
    "        head=self.snake[0]\n",
    "        #如果蛇试图做和上一步相反的动作,那么保持不变\n",
    "        if(abs(action-self.last_action))==2:\n",
    "            action=self.last_action\n",
    "        #根据动作决定蛇头的新位置\n",
    "        if action==0:#上\n",
    "            new_head=(head[0]-1,head[1])\n",
    "        elif action==1:#右\n",
    "            new_head=(head[0],head[1]+1)\n",
    "        elif action==2:#下\n",
    "            new_head=(head[0]+1,head[1])\n",
    "        else:#左\n",
    "            new_head=(head[0],head[1]-1)\n",
    "        self.last_action=action\n",
    "        self.steps+=1\n",
    "        #蛇头出现在蛇身上,游戏结束\n",
    "        if new_head in self.snake:\n",
    "            return self._get_state(),-100,True,self.steps>=self.max_steps,{}\n",
    "        #蛇头新位置出界了,游戏结束\n",
    "        if self._is_out_of_bounds(new_head):\n",
    "            return self._get_state(),-200,True,self.steps>=self.max_steps,{}\n",
    "        elif new_head in self.food:\n",
    "            #在蛇头前面插入一节新蛇头\n",
    "            self.snake.insert(0,new_head)\n",
    "            #重新生成食物\n",
    "            self.food=self._generate_food()\n",
    "            self.hunary=0\n",
    "            return self._get_state(),50,False,self.steps>=self.max_steps,{}\n",
    "        else:\n",
    "            #最前面插入一节新蛇头的新位置\n",
    "            self.snake.insert(0,new_head)\n",
    "            #去掉最后一节\n",
    "            self.snake.pop()\n",
     "            #默认奖励为0,饥饿度增加\n",
    "            r=0\n",
    "            self.hunary+=1\n",
    "            if self.hunary>=20:\n",
    "                r-=(1+(self.hunary-20)/100)\n",
    "            #返回对应结果\n",
     "            return self._get_state(),r,False,self.steps>=self.max_steps,{}\n",
    "    def sample(self):\n",
    "        return self.action_space.sample()\n",
    "    def render(self,mode=\"rgb_array\"):\n",
    "        img=np.zeros((self.grid_size,self.grid_size,3))\n",
    "        #绘制蛇体,蛇头是红色,蛇身是蓝色\n",
    "        for i,s in enumerate(self.snake):\n",
    "            if i==0:\n",
    "                img[s]=[1,0,0]\n",
    "            else:\n",
    "                img[s]=[0,0,1]\n",
    "        #食物为绿色\n",
    "        for f in self.food:\n",
    "            img[f]=[0,1,0]\n",
    "        return img\n",
    "    def _get_state(self):\n",
     "        return self.render()\n",
    "    def _is_out_of_bounds(self,position):\n",
    "        x,y=position\n",
    "        return x<0 or y<0 or x>=self.grid_size or y>=self.grid_size\n",
    "    def _generate_food(self):\n",
    "        foods=[]\n",
    "        while 1:\n",
    "            food=(random.randint(0,self.grid_size-1),random.randint(0,self.grid_size-1))\n",
    "            if food not in self.snake and food not in foods:\n",
    "                foods.append(food)\n",
    "            if len(foods)>=self.num_food:\n",
    "                return foods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "4e297a5e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-22T01:48:23.271764Z",
     "start_time": "2024-05-22T01:48:20.924852Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAARgAAAEnCAYAAAByolz0AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAYUklEQVR4nO3dfVAU5x0H8O+CsoCBS8DAccNLkKI0aEiV2miMSmOIxJC3NsGxQUw7GZ0hKmPbCdSkEKOQmMbYGUoZbavJJDGaaSFGaR0mIiSDrSgSHacTtUG5hlwJTnsHKIdyT/8wXnvhxTtuH3cPvx/nN5Pb29v9uTm+7gv7rCKEECAikiBI7waIaPxiwBCRNAwYIpKGAUNE0jBgiEgaBgwRScOAISJpGDBEJA0DhoikYcDcJHbv3o309HSEhYVBURQ89thjUBRlTMs6dOgQFEXBoUOHfPrcHXfcgYcffnhM6/TFuXPnoCgKdu7cKX1dNLoJejdA8n311VfIz8/H4sWLUVVVBVVVYbFYUFxcPKblzZw5E4cPH8add96pcac03jBgbgKnT5/G5cuX8fTTT2PBggXu6YmJiWNaXmRkJO655x6t2qNxjIdI49yKFSswb948AEBeXh4URcHChQtRVlY25BDp2iHMX/7yF8ycORNhYWFIS0vDH/7wB4/5hjtE+vzzz7F06VJYLBaoqorY2Fjcf//9aGtrG9LT9ZYPADabDStXrkR8fDxCQkKQnJyMl156CVeuXPGYr7OzE0899RQiIiJgMpmQl5cHm802xq1FWuMezDj34osvYvbs2SgsLER5eTmysrIQGRmJPXv2DDv/p59+ip/+9KcoLi5GbGwsfve73+EnP/kJvvWtb2H+/Pkjruehhx7C4OAgNm/ejMTERHR3d6O5uRn/+c9/fF6+zWbD7NmzERQUhF/+8pdISUnB4cOHsXHjRpw7dw47duwAAFy6dAmLFi1CZ2cnKioqMHXqVOzfvx95eXnabDzyn6Bxr6GhQQAQ77//vntaaWmp+Ob//qSkJBEaGirOnz/vnnbp0iURFRUlVq5cOWR5DQ0NQgghuru7BQCxdevWUfvwdvkrV64Ut9xyi8d8Qgjxq1/9SgAQp06dEkII8dvf/lYAEB988IHHfM8++6wAIHbs2DFqPyQfD5HIw9133+1xbiY0NBRTp07F+fPnR/xMVFQUUlJS8Nprr2HLli04fvw4XC7XmJe/b98+ZGVlwWKx4MqVK+7KyckBADQ2NgIAGhoaEBERgUceecRjHcuWLfP9L05SMGDIQ3R09JBpqqri0qVLI35GURR89NFHePDBB7F582bMnDkTt99+O9asWYOenh6fl/+vf/0LH374ISZOnOhR6enpAIDu7m4AwIULFxAbGztkeWaz2bu/LEnHczCkiaSkJPz+978HcPWq1Z49e1BWVoaBgQFUV1f7tKzJkyfjrrvuwqZNm4Z932KxALgaVkeOHBnyPk/yGgcDhjQ3depUvPDCC/jjH/+I1tZWnz//8MMPo66uDikpKbjttttGnC8rKwt79uzB3r17PQ6T3n333TH1TdpjwJDfTpw4geeeew5PPvkkUlNTERISgoMHD+LEiRNj+mW+DRs2oL6+HnPnzsWaNWswbdo09Pf349y5c6irq0N1dTXi4+OxfPlyvPHGG1i+fDk2bdqE1NRU1NXV4cCBAxL+ljQWDBjym9lsRkpKCqqqqmC1WqEoCqZMmYLXX38dq1ev9nl5cXFxOHr0KF5++WW89tpr+Oc//4mIiAgkJydj8eLF7r2a8PBwHDx4EGvXrkVxcTEURUF2djbee+89zJ07V+u/Jo2BIgSfKkBEcvAqEhFJw4AhImkYMEQkDQOGiKRhwBCRNAwYIpLGcL8H43K50NnZiYiIiDEP6UhE2hNCoKenBxaLBUFB3u2bGC5gOjs7kZCQoHcbRDQCq9WK+Ph4r+Y1XMBERETo3QJRYLD78VnT2D/q
y8+o4QKGh0VEXorUZ7W+/IxKO8lbVVWF5ORkhIaGYtasWfj4449lrYqIDEpKwOzevRtFRUVYv349jh8/jvvuuw85OTno6OiQsToiMioZ43DOnj1brFq1ymNaWlqaKC4uvu5n7Xa7AMBisa5X/vzxY712u93rLNB8D2ZgYADHjh1Ddna2x/Ts7Gw0NzcPmd/pdMLhcHgUEY0PmgdMd3c3BgcHh4yVGhsbO+xQhhUVFTCZTO7iJWqi8UPaSd5vnmkWQgx79rmkpAR2u91dVqtVVktEdINpfpl68uTJCA4OHrK30tXVNewI8KqqQlVVrdsgIgPQfA8mJCQEs2bNQn19vcf0a2OsEtHNQ8ov2q1btw75+fnIzMzEnDlzsG3bNnR0dGDVqlUyVkdEBiUlYPLy8nDhwgVs2LABX375JaZPn466ujokJSXJWB0RGZThBv12OBwwmUx6t0FkfP785PpxR47dbkdkpHf3KRjuXiS/6LTBiXQRAN9ZDjhFRNIwYIhIGgYMEUnDgCEiaRgwRCQNA4aIpGHAEJE0DBgikoYBQ0TSMGCISBoGDBFJw4AhImkYMEQkDQOGiKQZX8M1BMDt60Q3E+7BEJE0DBgikkbzgKmoqMB3v/tdREREICYmBo899hg+++wzrVdDRAFA84BpbGxEYWEh/vrXv6K+vh5XrlxBdnY2+vr6tF4VERmc9EG/v/rqK8TExKCxsRHz58+/7vwc9JvI2Aw16LfdbgcAREVFDfu+0+mE0+l0v3Y4HLJbIqIbRUjkcrlEbm6umDdv3ojzlJaWClx9HgCLxQqAstvtXmeA1EOkwsJC7N+/H5988gni4+OHnWe4PZiEhARZLRGRnwxxiLR69Wrs3bsXTU1NI4YLAKiqClVVZbVBRDrSPGCEEFi9ejVqampw6NAhJCcna70KIgoQmgdMYWEh3n33XXzwwQeIiIiAzWYDAJhMJoSFhWm9OiIyMM3PwSjK8DcE7dixAytWrLju53mZmsjYdD0HI/GcMREFmPF1N3WgGWsW865xChC82ZGIpGHAEJE0DBgikoYBQ0TSMGCISBoGDBFJw4AhImkYMEQkDQOGiKRhwBCRNAwYIpKGAUNE0jBgiEgaBgwRScPhGvTEYRdonOMeDBFJw4AhImmkB0xFRQUURUFRUZHsVRGRwUgNmJaWFmzbtg133XWXzNUQkUFJC5je3l786Ec/wvbt23HbbbfJWg0RGZi0gCksLMSSJUuwaNGiUedzOp1wOBweRUTjg5TL1O+99x5aW1vR0tJy3XkrKirw0ksvyWiDiHSm+R6M1WrF2rVr8fbbbyM0NPS685eUlMBut7vLarVq3RIR6UTzJzvW1tbi8ccfR3BwsHva4OAgFEVBUFAQnE6nx3vfxCc7Ehmbrk92vP/++3Hy5EmPac888wzS0tLw/PPPjxouRDS+aB4wERERmD59use0SZMmITo6esh0Ihrf+Ju8RCSN5udg/MVzMETGpus5mIDFB9ETaY6HSEQkDQOGiKRhwBCRNAwYIpKGAUNE0jBgiEgaBgwRScOAISJpGDBEJA0DhoikYcAQkTQMGCKShgFDRNIwYIhIGuMO12AH4N2QE//jz9AJHHaBSHPcgyEiaRgwRCSNlID54osv8PTTTyM6Ohrh4eG4++67cezYMRmrIiID0/wczL///W/ce++9yMrKwp///GfExMTgH//4B2699VatV0VEBqd5wLz66qtISEjAjh073NPuuOMOrVdDRAFA80OkvXv3IjMzE08++SRiYmLwne98B9u3bx9xfqfT6fHge4fDoXVLRKQXoTFVVYWqqqKkpES0traK6upqERoaKt58881h5y8tLRW4Oqa/Z9khfP4z3HJYLJamZbfbvc4DzZ+LFBISgszMTDQ3N7unrVmzBi0tLTh8+PCQ+Z1OJ5xOp/u1w+FAQkLCjf89GCLyii/PRdL8ECkuLg533nmnx7Rvf/vb6OjoGHZ+VVURGRnpUUQ0PmgeMPfe
ey8+++wzj2mnT59GUlKS1qsiIqPT5szL/xw5ckRMmDBBbNq0SZw5c0a88847Ijw8XLz99ttefd5ut1891uM5GBbLkOXLORjNA0YIIT788EMxffp0oaqqSEtLE9u2bfP6swwYFsvYpetJXn85HA6YTCae5CUyKF9O8hr3bmqT3g14yZ94ZiDSOMebHYlIGgYMEUnDgCEiaRgwRCQNA4aIpGHAEJE0DBgikoYBQ0TSMGCISBoGDBFJw4AhImkYMEQkDQOGiKRhwBCRNAwYIpKGAUNE0jBgiEgazQPmypUreOGFF5CcnIywsDBMmTIFGzZsgMvl0npVRGR0vg/pPbqNGzeK6OhosW/fPtHe3i7ef/99ccstt4itW7d69Xn3oN+BUv780bt3FmsM5cug35qPyXv48GE8+uijWLJkCYCrD77ftWsXjh49qvWqiMjgND9EmjdvHj766COcPn0aAPDpp5/ik08+wUMPPTTs/E6n0+PB9w6HQ+uWiEgvYz0UGonL5RLFxcVCURQxYcIEoSiKKC8vH3H+0tJS3Xf5/CoeIrFustL1wWu7du0S8fHxYteuXeLEiRPirbfeElFRUWLnzp3Dzt/f3y/sdru7rFar7hvQp2LAsG6y0jVg4uPjRWVlpce0l19+WUybNs2rz/MkL4tl7PIlYDQ/B3Px4kUEBXkuNjg4mJepiW5Cml9Fys3NxaZNm5CYmIj09HQcP34cW7ZswY9//GOtV0VERufT8Y8XHA6HWLt2rUhMTBShoaFiypQpYv369cLpdHr1eR4isVjGLl8OkRQhhICBOBwOmEwmvdvwnj9bj8+mpgBkt9sRGRnp1byaHyLddPQICYYaBQje7EhE0jBgiEgaBgwRScOAISJpGDBEJA0DhoikYcAQkTQMGCKShgFDRNIwYIhIGgYMEUnDgCEiaRgwRCQN76YORLwjmgIE92CISBoGDBFJ43PANDU1ITc3FxaLBYqioLa21uN9IQTKyspgsVgQFhaGhQsX4tSpU1r1S0QBxOeA6evrQ0ZGBiorK4d9f/PmzdiyZQsqKyvR0tICs9mMBx54AD09PX43S0QBZkwje38NgKipqXG/drlcwmw2i1deecU9rb+/X5hMJlFdXe3VMgNu0G8W6yYr3Z6L1N7eDpvNhuzsbPc0VVWxYMECNDc3a7kqIgoAml6mttlsAIDY2FiP6bGxsTh//vywn3E6nXA6ne7XDodDy5aISEdSriIpiucvagghhky7pqKiAiaTyV0JCQkyWiIiHWgaMGazGcD/9mSu6erqGrJXc01JSQnsdru7rFarli0RkY40DZjk5GSYzWbU19e7pw0MDKCxsRFz584d9jOqqiIyMtKjiGh88PkcTG9vL86ePet+3d7ejra2NkRFRSExMRFFRUUoLy9HamoqUlNTUV5ejvDwcCxbtkzTxokoAHh9velrDQ0Nw166KigoEEJcvVRdWloqzGazUFVVzJ8/X5w8edLr5fMyNYtl7OKzqYlIGl+eTc17kYhIGg7X4LcbvwMo/BivgSM90I3EPRgikoYBQ0TSMGCISBoGDBFJw4AhImkYMEQkDQOGiKRhwBCRNAwYIpKGAUNE0jBgiEgaBgwRScOAISJpeDe13278/cm8I/o6/LnBnRtXU9yDISJpGDBEJI3PAdPU1ITc3FxYLBYoioLa2lr3e5cvX8bzzz+PGTNmYNKkSbBYLFi+fDk6Ozu17JmIAoTPAdPX14eMjAxUVlYOee/ixYtobW3Fiy++iNbWVvzpT3/C6dOn8cgjj2jSLBEFGN+eKeAJgKipqRl1niNHjggA4vz5814tk08VYPld/vzRu/cAKF+eKiD9HIzdboeiKLj11ltlr4qIDEbqZer+/n4UFxdj2bJlIz7mwOl0wul0ul87HA6ZLRHRDSRtD+by5ctYunQpXC4XqqqqRpyvoqLC/eB7k8mEhIQEWS0R0Q0mJWAuX76Mp556Cu3t7aivrx/1IU0lJSXuB9/b7XZY
rVYZLRGRDjQ/RLoWLmfOnEFDQwOio6NHnV9VVaiqqnUbRGQAPgdMb28vzp49637d3t6OtrY2REVFwWKx4Ic//CFaW1uxb98+DA4OwmazAQCioqIQEhKiXedEZHxeX2/6WkNDw7CXrgoKCkR7e/uIl7YaGhp4mZp1Y4qXqaWWL5epFSGEgIE4HA6YTCa926BA5s83mjc7Xpfdbh/1vOr/471IRCQNh2vwF/+1NB5uV8PgHgwRScOAISJpGDBEJA0DhoikYcAQkTQMGCKShgFDRNIwYIhIGgYMEUnDgCEiaRgwRCQNA4aIpGHAEJE0vJvaX7xzl2hE3IMhImkYMEQkjc8B09TUhNzcXFgsFiiKgtra2hHnXblyJRRFwdatW/1okYgClc8B09fXh4yMDFRWVo46X21tLf72t7/BYrGMuTkiCmw+n+TNyclBTk7OqPN88cUXeO6553DgwAEsWbJkzM0RUWDT/ByMy+VCfn4+fv7znyM9PV3rxRNRANH8MvWrr76KCRMmYM2aNV7N73Q64XQ63a8dDofWLRGRTjTdgzl27Bh+/etfY+fOnVAU735BpKKiwv3ge5PJhISEBC1bIiI9+fpkx/8HQNTU1Lhfv/HGG0JRFBEcHOwuACIoKEgkJSUNu4z+/n5ht9vdZbVadX9yHYvFGrl8ebKjpodI+fn5WLRokce0Bx98EPn5+XjmmWeG/YyqqlBVVcs2iMggfA6Y3t5enD171v26vb0dbW1tiIqKQmJiIqKjoz3mnzhxIsxmM6ZNm+Z/t0QUUHwOmKNHjyIrK8v9et26dQCAgoIC7Ny5U7PGiCjwKV+fSzEMh8MBk8mkdxtENAK73Y7IyEiv5uW9SEQkDQOGiKRhwBCRNAwYIpKGAUNE0jBgiEgaBgwRScOAISJpGDBEJA0DhoikYcAQkTQMGCKShgFDRNIwYIhIGsMFjMFGjyCib/DlZ9RwAdPT06N3C0Q0Cl9+Rg034JTL5UJnZyciIiKGfTKBw+FAQkICrFar14Pe3Ey4fUbH7TO60baPEAI9PT2wWCwICvJu30Tz5yL5KygoCPHx8dedLzIykl+QUXD7jI7bZ3QjbR9fR5s03CESEY0fDBgikibgAkZVVZSWlvJZSiPg9hkdt8/otN4+hjvJS0TjR8DtwRBR4GDAEJE0DBgikoYBQ0TSBFTAVFVVITk5GaGhoZg1axY+/vhjvVsyjLKyMiiK4lFms1nvtnTT1NSE3NxcWCwWKIqC2tpaj/eFECgrK4PFYkFYWBgWLlyIU6dO6dOsDq63fVasWDHk+3TPPff4vJ6ACZjdu3ejqKgI69evx/Hjx3HfffchJycHHR0derdmGOnp6fjyyy/ddfLkSb1b0k1fXx8yMjJQWVk57PubN2/Gli1bUFlZiZaWFpjNZjzwwAM3zb1w19s+ALB48WKP71NdXZ3vKxIBYvbs2WLVqlUe09LS0kRxcbFOHRlLaWmpyMjI0LsNQwIgampq3K9dLpcwm83ilVdecU/r7+8XJpNJVFdX69Chvr65fYQQoqCgQDz66KN+Lzsg9mAGBgZw7NgxZGdne0zPzs5Gc3OzTl0Zz5kzZ2CxWJCcnIylS5fi888/17slQ2pvb4fNZvP4PqmqigULFvD79H8OHTqEmJgYTJ06Fc8++yy6urp8XkZABEx3dzcGBwcRGxvrMT02NhY2m02nrozle9/7Ht566y0cOHAA27dvh81mw9y5c3HhwgW9WzOca98Zfp9GlpOTg3feeQcHDx7E66+/jpaWFnz/+9+H0+n0aTmGu5t6NN8cvkEIMeyQDjejnJwc93/PmDEDc+bMQUpKCt58802sW7dOx86Mi9+nkeXl5bn/e/r06cjMzERSUhL279+PJ554wuvlBMQezOTJkxEcHDzkX5eurq4h/wrRVZMmTcKMGTNw5swZvVsxnGtX1/h98l5cXBySkpJ8/j4FRMCEhIRg1qxZqK+v95heX1+PuXPn6tSVsTmdTvz9739H
XFyc3q0YTnJyMsxms8f3aWBgAI2Njfw+jeDChQuwWq0+f58C5hBp3bp1yM/PR2ZmJubMmYNt27aho6MDq1at0rs1Q/jZz36G3NxcJCYmoqurCxs3boTD4UBBQYHeremit7cXZ8+edb9ub29HW1sboqKikJiYiKKiIpSXlyM1NRWpqakoLy9HeHg4li1bpmPXN85o2ycqKgplZWX4wQ9+gLi4OJw7dw6/+MUvMHnyZDz++OO+rcjv61A30G9+8xuRlJQkQkJCxMyZM0VjY6PeLRlGXl6eiIuLExMnThQWi0U88cQT4tSpU3q3pZuGhgYBYEgVFBQIIa5eqi4tLRVms1moqirmz58vTp48qW/TN9Bo2+fixYsiOztb3H777WLixIkiMTFRFBQUiI6ODp/Xw+EaiEiagDgHQ0SBiQFDRNIwYIhIGgYMEUnDgCEiaRgwRCQNA4aIpGHAEJE0DBgikoYBQ0TSMGCISBoGDBFJ81814KPAfh29LAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "env=SnakeEnv()\n",
    "env.reset()\n",
    "gym_helper=GymHelper(env)\n",
    "for i in range(20):\n",
    "    gym_helper.render(title=str(i))\n",
    "    action=env.action_space.sample()\n",
    "    observation,reward,terminated,truncated,info=env.step(action)\n",
    "    if terminated or truncated:\n",
    "        break\n",
    "gym_helper.render(title=\"finished\")\n",
    "env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "77a74ae3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-22T01:57:46.583327Z",
     "start_time": "2024-05-22T01:57:46.570490Z"
    }
   },
   "outputs": [],
   "source": [
    "#策略模型，给定状态生成各个动作的概率\n",
    "class Policymodel(nn.Module):\n",
    "    def __init__(self,grid_size,output_dim):\n",
    "        super(Policymodel,self).__init__()\n",
    "        self.grid_size=grid_size\n",
    "        self.output_dim=output_dim\n",
    "        self.conv=nn.Sequential(\n",
    "            nn.Conv2d(3,8,kernel_size=1),\n",
    "            nn.BatchNorm2d(8),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(8,4,kernel_size=1),\n",
    "            nn.BatchNorm2d(4),\n",
    "            nn.ReLU()\n",
    "        )\n",
    "        self.fc=nn.Sequential(\n",
    "            nn.Flatten(),\n",
    "            nn.Linear(grid_size*grid_size*4,128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128,128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128,self.output_dim),\n",
    "            nn.Softmax(dim=1)\n",
    "        )\n",
    "        #dueling networks\n",
    "    def forward(self,state):\n",
    "        x=state.view(-1,3,self.grid_size,self.grid_size)\n",
    "        x=self.conv(x)\n",
    "        action_prob=self.fc(x)\n",
    "        return action_prob\n",
    "#价值模型，给定状态的估计值\n",
    "class Valuemodel(nn.Module):\n",
    "    def __init__(self,grid_size):\n",
    "        super(Valuemodel,self).__init__()\n",
    "        self.grid_size=grid_size\n",
    "        #self.output_dim=output_dim\n",
    "        self.conv=nn.Sequential(\n",
    "            nn.Conv2d(3,8,kernel_size=1),\n",
    "            nn.BatchNorm2d(8),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(8,4,kernel_size=1),\n",
    "            nn.BatchNorm2d(4),\n",
    "            nn.ReLU()\n",
    "        )\n",
    "        self.fc=nn.Sequential(\n",
    "            nn.Linear(grid_size*grid_size*4,128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128,128),\n",
    "            nn.ReLU(),\n",
    "#             nn.Linear(128,self.output_dim),\n",
    "            nn.Linear(128,1)\n",
    "        )\n",
    "        #dueling networks\n",
    "    def forward(self,x):\n",
    "        x=x.view(-1,3,self.grid_size,self.grid_size)\n",
    "        x=self.conv(x)\n",
    "        value=self.fc(x)\n",
    "        return value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "eed8d1c6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-22T02:04:16.546591Z",
     "start_time": "2024-05-22T02:04:16.527524Z"
    }
   },
   "outputs": [],
   "source": [
    "class PPO:\n",
    "    def __init__(self,env,lr=0.001,gamma=0.99,lamda=0.95,eps=0.2,epochs=20):\n",
    "        self.env=env\n",
    "        self.lr=lr\n",
    "        self.gamma=gamma\n",
    "        self.lamda=lamda\n",
    "        self.eps=eps\n",
    "        self.epochs=epochs\n",
    "        #判断可用设备是CPU与GPU\n",
    "        self.device=torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "        #定义策略网络与价值网络\n",
    "        self.policy_model=Policymodel(env.observation_space.shape[0],env.action_space.n).to(self.device)\n",
    "        self.value_model=Valuemodel(env.observation_space.shape[0]).to(self.device)\n",
    "        self.policy_optimizer=torch.optim.Adam(self.policy_model.parameters(),lr=lr)\n",
    "        self.value_optimizer=torch.optim.Adam(self.value_model.parameters(),lr=lr)\n",
    "    def choose_action(self,state):\n",
    "        #state=state.astype(float)\n",
    "        state=torch.tensor(np.array(state),dtype=torch.float32).to(self.device)\n",
    "        with torch.no_grad():\n",
    "            action_prob=self.policy_model(state)\n",
    "        c=torch.distributions.Categorical(action_prob)\n",
    "        action=c.sample()\n",
     "        return action.item()\n",
    "    def calc_advantage(self,td_delta):\n",
    "        td_delta=td_delta.cpu().detach().numpy()\n",
    "        #初始化\n",
    "        advantage=0\n",
    "        advantage_list=[]\n",
    "        for r in td_delta[::-1]:\n",
    "            #将上一步的TDerror和上一步的优势加权为当前的优势\n",
     "            advantage=r+self.gamma*self.lamda*advantage\n",
    "            #将优势值加到列表开头,最终得到顺序序列\n",
    "            advantage_list.insert(0,advantage)\n",
    "        return torch.FloatTensor(np.array(advantage_list)).to(self.device)\n",
    "    def update(self,batch):\n",
    "        states,actions,rewards,next_states,dones=zip(*batch)\n",
    "        states=torch.FloatTensor(np.array(states)).to(self.device)\n",
     "        actions=torch.LongTensor(np.array(actions)).view(-1,1).to(self.device)\n",
    "        rewards=torch.FloatTensor(np.array(rewards)).view(-1,1).to(self.device)\n",
    "        next_states=torch.FloatTensor(np.array(next_states)).to(self.device)\n",
    "        dones=torch.FloatTensor(np.array(dones)).view(-1,1).to(self.device)\n",
    "        with torch.no_grad():\n",
     "            #计算旧策略下所选动作的对数概率\n",
    "            old_action_prob=torch.log(self.policy_model(states).gather(1,actions))\n",
    "            #计算TD目标以及误差\n",
     "            td_target=rewards+(1-dones)*self.gamma*self.value_model(next_states)\n",
    "            td_delta=td_target-self.value_model(states)\n",
    "        #优势估计\n",
    "        advantage=self.calc_advantage(td_delta)\n",
    "        for i in range(self.epochs):\n",
    "            #计算策略下的动作概率\n",
    "            actions=actions.type(torch.long).to(self.device)\n",
    "            action_prob=torch.log(self.policy_model(states).gather(1,actions))\n",
    "            #计算策略动作概率比\n",
    "            ratio=torch.exp(action_prob-old_action_prob)\n",
    "            #clip修剪\n",
    "            part1=ratio*advantage\n",
     "            part2=torch.clamp(ratio,1-self.eps,1+self.eps)*advantage\n",
    "            #计算策略损失\n",
     "            policy_loss=-torch.min(part1,part2).mean()\n",
    "            #计算价值损失\n",
    "            value_loss=F.mse_loss(self.value_model(states),td_target).mean()\n",
    "            \n",
    "            #梯度清零,反向传播,参数更新\n",
    "            self.policy_optimizer.zero_grad()\n",
    "            self.value_optimizer.zero_grad()\n",
    "            policy_loss.backward()\n",
    "            value_loss.backward()\n",
     "            self.policy_optimizer.step()\n",
    "            self.value_optimizer.step()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "0ee3413c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-22T02:04:17.102264Z",
     "start_time": "2024-05-22T02:04:17.065092Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                         | 0/1000 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_10004\\2851397216.py\u001b[0m in \u001b[0;36m<cell line: 8>\u001b[1;34m()\u001b[0m\n\u001b[0;32m     11\u001b[0m     \u001b[0mbuffer\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     12\u001b[0m     \u001b[1;32mfor\u001b[0m \u001b[0mstep\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmax_steps\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 13\u001b[1;33m         \u001b[0maction\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0magent\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mchoose_action\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     14\u001b[0m         \u001b[0mnext_state\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mreward\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mterminated\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mtruncated\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0minfo\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0menv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0maction\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     15\u001b[0m         \u001b[0mdone\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mterminated\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mtruncated\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_10004\\428130858.py\u001b[0m in \u001b[0;36mchoose_action\u001b[1;34m(self, state)\u001b[0m\n\u001b[0;32m     16\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mchoose_action\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     17\u001b[0m         \u001b[1;31m#state=state.astype(float)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 18\u001b[1;33m         \u001b[0mstate\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mdtype\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     19\u001b[0m         \u001b[1;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     20\u001b[0m             \u001b[0maction_prob\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpolicy_model\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mTypeError\u001b[0m: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool."
     ]
    }
   ],
   "source": [
    "max_episodes=1000\n",
    "max_steps=200\n",
    "#batch_size=32\n",
    "\n",
    "agent=PPO(env)\n",
    "eps_rewards=[]\n",
    "\n",
    "for episode in tqdm(range(max_episodes)):\n",
    "    state,_=env.reset()\n",
    "    eps_reward=0\n",
    "    buffer=[]\n",
    "    for step in range(max_steps):\n",
    "        action=agent.choose_action(state)\n",
    "        next_state,reward,terminated,truncated,info=env.step(action)\n",
    "        done=terminated or truncated\n",
    "        buffer.append((state,action,reward,next_state,done))\n",
    "        eps_reward+=reward\n",
    "#         if len(agent.replay_buffer)>batch_size:\n",
    "#             agent.update(batch_size)\n",
    "        state=next_state\n",
    "        if done:\n",
    "            break\n",
    "    agent.update(buffer)\n",
    "    eps_rewards.append(eps_reward)\n",
    "    if episode % 40==0:\n",
    "        tqdm.write(\"Episode\"+str(episode)+\":\"+str(eps_reward))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "c3cef6d4",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-22T02:05:44.663612Z",
     "start_time": "2024-05-22T02:05:44.379723Z"
    }
   },
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_10004\\1657158831.py\u001b[0m in \u001b[0;36m<cell line: 4>\u001b[1;34m()\u001b[0m\n\u001b[0;32m      4\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m20\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      5\u001b[0m     \u001b[0mgym_helper\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtitle\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mstr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 6\u001b[1;33m     \u001b[0maction\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0magent\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mchoose_action\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mobservation\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      7\u001b[0m     \u001b[0mobservation\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mreward\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mterminated\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mtruncated\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0minfo\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0menv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0maction\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      8\u001b[0m     \u001b[0mdone\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mterminated\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mtruncated\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_10004\\428130858.py\u001b[0m in \u001b[0;36mchoose_action\u001b[1;34m(self, state)\u001b[0m\n\u001b[0;32m     16\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0mchoose_action\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     17\u001b[0m         \u001b[1;31m#state=state.astype(float)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 18\u001b[1;33m         \u001b[0mstate\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0marray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mdtype\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     19\u001b[0m         \u001b[1;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     20\u001b[0m             \u001b[0maction_prob\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpolicy_model\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mTypeError\u001b[0m: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool."
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAARgAAAEnCAYAAAByolz0AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAVmElEQVR4nO3df2xVd/3H8ddtobfA2jtbbG9v+sNO2Top65S6CdsY01GtpE7QCaK1w8SIwW1NoxkVTXFuLVsizqQSAn8MFjY2jbbODV0aZWULc+v4oUiWAVrXm3VNXWfupSW9/Ojn+8fk+r2jLb30fHrOLc8HeSfcc8+9553D6Yvzo+dzfMYYIwCwIM3tBgBMXwQMAGsIGADWEDAArCFgAFhDwACwhoABYA0BA8AaAgaANQQMrBgcHFR9fb1CoZAyMzN144036umnn3a7LUyxGW43gOlp5cqV6urq0ubNm3Xttdfqqaee0te+9jWNjIxozZo1breHKeLjXiQ4be/evVq+fHk8VC6oqqrSsWPH1NPTo/T0dBc7xFThEAmOa2tr01VXXaW77747YfratWvV29urV1991aXOMNUIGDju73//u66//nrNmJF4BH7DDTfE38eVgYCB4wYGBpSTk3PR9AvTBgYGproluISAgRU+n++y3sP0QsDAcbm5uaPupbz33nuSNOreDaYnAgaOW7Bggd544w2dO3cuYfrRo0clSeXl5W60BRcQMHDcihUrNDg4qN/85jcJ03ft2qVQKKSbb77Zpc4w1fhFOziuurpay5Yt03e/+11Fo1F97GMf0549e/THP/5Ru3fv5ndgriD8oh2sGBwc1MaNG/WrX/1K7733nsrKytTY2KjVq1e73RqmEAEDwBrOwQCwhoABYA0BA8AaAgaANQQMAGsIGADWeO4X7UZGRtTb26usrCxuigM8xBijU6dOKRQKKS1tYvsmnguY3t5eFRUVud0GgDGEw2EVFhZOaF7PBUxWVpbbLcArIpf5uYCjXeADkvkZ9VzAcFiEuGy3G8BokvkZtXaSd+vWrSotLVVmZqYWLlyol156ydaiAHiUlYB55plnVF9fr40bN+rw4cO67bbbVF1drZ6eHhuLA+BVxoKbbrrJrFu3LmFaWVmZ2bBhwyU/G4lEjCSKMpf9x+2+p3lFIpEJZ4HjezBnzpzRwYMHVVVVlTC9qqpKBw4cuGj+WCymaDSaUACmB8cD5t1339X58+eVn5+fMD0/P199fX0Xzd/S0qJAIBAvLlED04e1k7wfPNNsjBn17HNjY6MikUi8wuGwrZYATDHHL1PPnTtX6enpF+2t9Pf3X7RXI0l+v19+v9/pNgB4gON7MBkZGVq4cKE6OjoSpnd0dGjx4sVOLw6Ah1n5RbuGhgbV1taqsrJSixYt0vbt29XT06N169bZWBwAj7ISMKtWrdLAwIAefPBBvfPOOyovL9fevXtVUlJiY3EAPMpzg35Ho1EFAgG324AXXO6Wyd0mVkUiEWVnT+w+Ds/diwTEXSlBMY2DlAGnAFhDwACwhoABYA0BA8AaAgaANQQMAGsIGADWEDAArCFgAFhDwACwhoABYA0BA8AaAgaANQQMAGsYrgFwwmRGVUqBYRcuF3swAKwhYABY43jAtLS06FOf+pSysrKUl5enL33pS3rzzTedXgyAFOB4wHR2dmr9+vX6y1/+oo6ODp07d05VVVUaGhpyelEAPM76oN///ve/lZeXp87OTi1ZsuSS8zPoN1LSFXSS11ODfkciEUlSTk7OqO/HYjHFYrH462g0arslAFPE6kleY4waGhp06623qry8fNR5Wlpa4g++DwQCKioqstkSgClk9RBp/fr1ev755/Xyyy+rsLBw1HlG24MhZJByOEQalbVDpHvvvVfPPvus9u/fP2a4SJLf75ff77fVBgAXOR4wxhjde++9amtr04svvqjS0lKnFwEgRTgeMOvXr9dTTz2l3/3ud8rK
ylJfX58kKRAIaNasWU4vDoCHOX4Oxucb/YDy8ccf1z333HPJz3OZGimJczCjsnKIBAASd1MDzkixvZCpws2OAKwhYABYQ8AAsIaAAWANAQPAGgIGgDUEDABrCBgA1hAwAKwhYABYQ8AAsIaAAWANAQPAGgIGgDUEDABrCBgA1hAwAKyxHjAtLS3y+Xyqr6+3vSgAHmM1YLq6urR9+3bdcMMNNhcDwKOsBczg4KC+/vWva8eOHfrQhz5kazEAPMxawKxfv17Lly/XnXfeOe58sVhM0Wg0oQBMD1aeKvD000/r0KFD6urquuS8LS0t+slPfmKjDQAuc3wPJhwO6/7779fu3buVmZl5yfkbGxsViUTiFQ6HnW4JgEscf7Jje3u7VqxYofT09Pi08+fPy+fzKS0tTbFYLOG9D+LJjoC3ufpkx89+9rM6evRowrS1a9eqrKxMDzzwwLjhAmB6cTxgsrKyVF5enjBtzpw5ys3NvWg6gOmN3+QFYI3j52Ami3MwgLclcw6GPRgA1hAwAKwhYABYQ8AAsIaAAWANAQPAGgIGgDUEDABrCBgA1hAwAKwhYABYQ8AAsIaAAWANAQPAGgIGgDVWnioAxE1mtCGfY13AJezBALCGgAFgjZWAefvtt/WNb3xDubm5mj17tm688UYdPHjQxqIAeJjj52D+85//6JZbbtEdd9yhP/zhD8rLy9M//vEPXX311U4vCoDHOR4wjzzyiIqKivT444/Hp33kIx9xejEAUoDjh0jPPvusKisrdffddysvL0+f+MQntGPHjjHnj8ViCQ++j0ajTrcEwC3GYX6/3/j9ftPY2GgOHTpktm3bZjIzM82uXbtGnb+pqcno/YuZ1HSsyfxxu3dq1IpEIhPOA8efi5SRkaHKykodOHAgPu2+++5TV1eXXnnllYvmj8ViisVi8dfRaFRFRUVOtgQ3TWbr4vdgPMnV5yIVFBTo4x//eMK066+/Xj09PaPO7/f7lZ2dnVAApgfHA+aWW27Rm2++mTDt+PHjKikpcXpRALzOmTMv//Paa6+ZGTNmmIcffticOHHCPPnkk2b27Nlm9+7dE/p8JBJx/RiTcrA4BzPtKplzMI4HjDHG/P73vzfl5eXG7/ebsrIys3379gl/loCZZkXATLty9STvZEWjUQUCAbfbgFMms3VxkteTkjnJy93UsGsyIXG54UQweQY3OwKwhoABYA0BA8AaAgaANQQMAGsIGADWEDAArCFgAFhDwACwhoABYA0BA8AaAgaANQQMAGu4mxpWTWq0Bu6KHl8KDIXBHgwAawgYANYQMACscTxgzp07px/96EcqLS3VrFmzdM011+jBBx/UyMiI04sC4HXJD+k9voceesjk5uaa5557znR3d5tf//rX5qqrrjKPPfbYhD7PoN/Tq8wkyu3ePV8uDaiezKDfjl9FeuWVV3TXXXdp+fLlkt5/8P2ePXv0+uuvO70oAB7n+CHSrbfeqj/96U86fvy4JOmvf/2rXn75ZX3hC18Ydf5YLJbw4PtoNOp0SwDccrmHQmMZGRkxGzZsMD6fz8yYMcP4fD7T3Nw85vxNTU3u72pS1spMotzu3fOVAodIjgfMnj17TGFhodmzZ4/529/+Zp544gmTk5Njdu7cOer8w8PDJhKJxCscDrv/D0c5VmYS5Xbvnq8rMWAKCwtNa2trwrSf/vSn5rrrrpvQ5znJO73KTKLc7t3zlQIB4/g5mNOnTystLfFr09PTuUwNXIEcv4pUU1Ojhx9+WMXFxZo/f74OHz6sLVu26Fvf+pbTiwLgdUkd/0xANBo1999/vykuLjaZmZnmmmuuMRs3bjSxWGxCn+cQaXqVmUS53bvnKwUOkXzGGCMPiUajCgQCbrcBh0xm4+Jm6ktwaeVGIhFlZ2dPaF6Ga4BVhIRFKbByudkRgDUEDABrCBgA1hAwAKwhYABYQ8AAsIaAAWANAQPAGgIGgDUEDABrCBgA1hAwAKwhYABY4927qSOS
JnZH+P+kwN2lmKYYl2JU7MEAsIaAAWBN0gGzf/9+1dTUKBQKyefzqb29PeF9Y4w2bdqkUCikWbNmaenSpTp27JhT/QJIIUkHzNDQkCoqKtTa2jrq+48++qi2bNmi1tZWdXV1KRgMatmyZTp16tSkmwWQYi5rZO//kmTa2trir0dGRkwwGDSbN2+OTxseHjaBQMBs27ZtQt8ZH/Q7MrUDGVPUpMqlAbjdKNeei9Td3a2+vj5VVVXFp/n9ft1+++06cOCAk4sCkAIcvUzd19cnScrPz0+Ynp+fr7feemvUz8RiMcVisfjraDTqZEsAXGTlKpLPl3hh3xhz0bQLWlpaFAgE4lVUVGSjJQAucDRggsGgpP/tyVzQ399/0V7NBY2NjYpEIvEKh8NOtgTARY4GTGlpqYLBoDo6OuLTzpw5o87OTi1evHjUz/j9fmVnZycUgOkh6XMwg4ODOnnyZPx1d3e3jhw5opycHBUXF6u+vl7Nzc2aN2+e5s2bp+bmZs2ePVtr1qxxtHEAKWDC15v+a9++faNeuqqrqzPGvH+puqmpyQSDQeP3+82SJUvM0aNHJ/z9XKamUrK4TD0q7z6bmpsdkUom81OUYtttMs+m5l4kANZ4d7iGgNsNAElIsb2QqcIeDABrCBgA1hAwAKwhYABYQ8AAsIaAAWANAQPAGgIGgDUEDABrCBgA1hAwAKwhYABYQ8AAsIaAAa5ElzPUVCT5xRAwAKwhYABYk3TA7N+/XzU1NQqFQvL5fGpvb4+/d/bsWT3wwANasGCB5syZo1AopG9+85vq7e11smcAKSLpgBkaGlJFRYVaW1sveu/06dM6dOiQfvzjH+vQoUP67W9/q+PHj+uLX/yiI80CSDHJPVMgkSTT1tY27jyvvfaakWTeeuutCX1n/KkCFEXZq8v5E3n/s8k8VcD6OZhIJCKfz6err77a9qIAeIzVQb+Hh4e1YcMGrVmzZszHHMRiMcVisfjraDRqsyUAU8jaHszZs2e1evVqjYyMaOvWrWPO19LSEn/wfSAQUFFRka2WAEwxKwFz9uxZffWrX1V3d7c6OjrGfUhTY2Nj/MH3kUhE4XDYRksAXOD4IdKFcDlx4oT27dun3Nzccef3+/3y+/1OtwHAA5IOmMHBQZ08eTL+uru7W0eOHFFOTo5CoZC+8pWv6NChQ3ruued0/vx59fX1SZJycnKUkZHhXOcAvG/C15v+a9++faNe9qqrqzPd3d1jXhbbt28fl6kpyis1RZepk96DWbp0qYwxY74/3nsArizciwTAGqu/BwNgAi53p983iWVO5rNJYA8GgDUEDABrCBgA1hAwAKwhYABYQ8AAsIaAAWANAQPAGgIGgDUEDABrCBgA1hAwAKwhYABYw93UgNum6M5mN7AHA8AaAgaANUkHzP79+1VTU6NQKCSfz6f29vYx5/3Od74jn8+nxx57bBItAkhVSQfM0NCQKioq1NraOu587e3tevXVVxUKhS67OQCpLemTvNXV1aqurh53nrffflvf+9739MILL2j58uWX3RyA1Ob4OZiRkRHV1tbqBz/4gebPn+/01wNIIY5fpn7kkUc0Y8YM3XfffROaPxaLKRaLxV9Ho1GnWwLgEkf3YA4ePKhf/OIX2rlzp3y+iV3cb2lpiT/4PhAIqKioyMmWALgp2Sc7/n+STFtbW/z1z3/+c+Pz+Ux6enq8JJm0tDRTUlIy6ncMDw+bSCQSr3A47P5T7yiKGrOsPtlxPLW1tbrzzjsTpn3uc59TbW2t1q5dO+pn/H6//H6/k20A8IikA2ZwcFAnT56Mv+7u7taRI0eUk5Oj4uJi5ebmJsw/c+ZMBYNBXXfddZPvFkBKSTpgXn/9dd1xxx3x1w0NDZKkuro67dy507HGAKQ+n/HY0+qj0agCgYDbbQAYQyQSUXZ29oTm5V4kANYQMACsIWAAWEPAALCGgAFgDQEDwBoCBoA1BAwAawgYANYQMACsIWAAWEPAALCGgAFg
DQEDwBrPBYzHRo8A8AHJ/Ix6LmBOnTrldgsAxpHMz6jnBpwaGRlRb2+vsrKyRn0yQTQaVVFRkcLh8IQHvbmSsH7Gx/oZ33jrxxijU6dOKRQKKS1tYvsmjj8XabLS0tJUWFh4yfmys7PZQMbB+hkf62d8Y62fZEeb9NwhEoDpg4ABYE3KBYzf71dTUxPPUhoD62d8rJ/xOb1+PHeSF8D0kXJ7MABSBwEDwBoCBoA1BAwAa1IqYLZu3arS0lJlZmZq4cKFeumll9xuyTM2bdokn8+XUMFg0O22XLN//37V1NQoFArJ5/Opvb094X1jjDZt2qRQKKRZs2Zp6dKlOnbsmDvNuuBS6+eee+65aHv69Kc/nfRyUiZgnnnmGdXX12vjxo06fPiwbrvtNlVXV6unp8ft1jxj/vz5euedd+J19OhRt1tyzdDQkCoqKtTa2jrq+48++qi2bNmi1tZWdXV1KRgMatmyZVfMvXCXWj+S9PnPfz5he9q7d2/yCzIp4qabbjLr1q1LmFZWVmY2bNjgUkfe0tTUZCoqKtxuw5Mkmba2tvjrkZEREwwGzebNm+PThoeHTSAQMNu2bXOhQ3d9cP0YY0xdXZ256667Jv3dKbEHc+bMGR08eFBVVVUJ06uqqnTgwAGXuvKeEydOKBQKqbS0VKtXr9Y///lPt1vypO7ubvX19SVsT36/X7fffjvb0//z4osvKi8vT9dee62+/e1vq7+/P+nvSImAeffdd3X+/Hnl5+cnTM/Pz1dfX59LXXnLzTffrCeeeEIvvPCCduzYob6+Pi1evFgDAwNut+Y5F7YZtqexVVdX68knn9Sf//xn/exnP1NXV5c+85nPKBaLJfU9nrubejwfHL7BGDPqkA5Xourq6vjfFyxYoEWLFumjH/2odu3apYaGBhc78y62p7GtWrUq/vfy8nJVVlaqpKREzz//vFauXDnh70mJPZi5c+cqPT39ov9d+vv7L/pfCO+bM2eOFixYoBMnTrjdiudcuLrG9jRxBQUFKikpSXp7SomAycjI0MKFC9XR0ZEwvaOjQ4sXL3apK2+LxWJ64403VFBQ4HYrnlNaWqpgMJiwPZ05c0adnZ1sT2MYGBhQOBxOentKmUOkhoYG1dbWqrKyUosWLdL27dvV09OjdevWud2aJ3z/+99XTU2NiouL1d/fr4ceekjRaFR1dXVut+aKwcFBnTx5Mv66u7tbR44cUU5OjoqLi1VfX6/m5mbNmzdP8+bNU3Nzs2bPnq01a9a42PXUGW/95OTkaNOmTfryl7+sgoIC/etf/9IPf/hDzZ07VytWrEhuQZO+DjWFfvnLX5qSkhKTkZFhPvnJT5rOzk63W/KMVatWmYKCAjNz5kwTCoXMypUrzbFjx9xuyzX79u0zki6quro6Y8z7l6qbmppMMBg0fr/fLFmyxBw9etTdpqfQeOvn9OnTpqqqynz4wx82M2fONMXFxaaurs709PQkvRyGawBgTUqcgwGQmggYANYQMACsIWAAWEPAALCGgAFgDQEDwBoCBoA1BAwAawgYANYQMACsIWAAWPN/YUC+7PcNuXwAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "observation,_=env.reset()\n",
    "gym_helper=GymHelper(env,figsize=(3,3))\n",
    "agent=PPO(env)\n",
    "for i in range(20):\n",
    "    gym_helper.render(title=str(i))\n",
    "    action=agent.choose_action(observation)\n",
    "    observation,reward,terminated,truncated,info=env.step(action)\n",
    "    done=terminated or truncated\n",
    "    time.sleep(0.5)\n",
    "    if done:\n",
    "        break\n",
    "gym_helper.render(title=\"finished\")\n",
    "env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "459a8b13",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
